/**
 * Copyright (c) 2024 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */
#include "../../../common/data_utils.h"
#include "graph/tensor.h"
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <vector>
#ifndef ASCENDC_CPU_DEBUG
#include "acl/acl.h"
extern void broadcast_custom_do(uint32_t coreDim, void *l2ctrl, void *stream, uint8_t *x, uint8_t *y, uint8_t *workspace, uint8_t *tiling);
#else
#include "tikicpulib.h"
extern "C" __global__ __aicore__ void broadcast_custom(GM_ADDR x, GM_ADDR y, GM_ADDR workspace, GM_ADDR tiling);
#endif

// Number of uint32_t words in the tiling blob handed to the kernel; the tiling
// buffer is sized as TILINGDATA_SIZE * sizeof(uint32_t).
// NOTE(review): must stay in sync with the tiling layout in broadcast_custom_tiling.h.
constexpr uint32_t TILINGDATA_SIZE{6};

int64_t CompareResult(void* outputData, int64_t outSize)
{
    void* goldenData;
#ifdef ASCENDC_CPU_DEBUG
    goldenData = (uint8_t*)AscendC::GmAlloc(outSize);
#else
    CHECK_ACL(aclrtMallocHost((void**)(&goldenData), outSize));
#endif
    size_t goldenSize = outSize;
    bool ret = ReadFile("../output/golden.bin", goldenSize, goldenData, goldenSize);
    if (ret) {
        printf("ReadFile golden success!\n");
    } else {
#ifdef ASCENDC_CPU_DEBUG
        AscendC::GmFree((void *)goldenData);
#else
        CHECK_ACL(aclrtFreeHost(goldenData));
#endif
        return -1;
    }
    constexpr float EPS = 1e-5;
    int64_t wrongNum = 0;

    for (int i = 0; i < outSize / sizeof(float); i++) {
        float a = ((float*)outputData)[i];
        float b = ((float*)goldenData)[i];
        float ae = std::abs(a - b);
        float re = ae / abs(b);
        if(ae > EPS && re > EPS) {
            printf("CompareResult failed output is %lf, golden is %lf\n",a,b);
            wrongNum++;
        }
    }
#ifdef ASCENDC_CPU_DEBUG
    AscendC::GmFree((void*)goldenData);
#else
    CHECK_ACL(aclrtFreeHost(goldenData));
#endif
    return wrongNum;
}

uint8_t* GenerateTiling(const ge::Shape &inputShape, const ge::Shape &outputShape, uint32_t dtypeSize);

#ifdef ASCENDC_CPU_DEBUG
/**
 * Run broadcast_custom in the CPU debug simulator, write ../output/output.bin and
 * compare it against the golden file.
 *
 * @param blockDim   number of cores passed to ICPU_RUN_KF.
 * @param inputSize  size in bytes of the input tensor x.
 * @param outputSize size in bytes of the output tensor y.
 * @param tilingBuf  host buffer of tilingSize bytes produced by GenerateTiling.
 * @param tilingSize size in bytes of the tiling data.
 * @return CompareResult's mismatch count, or -1 if the input file could not be read.
 */
static int64_t RunBroadcast(uint32_t blockDim, size_t inputSize, size_t outputSize,
                            const uint8_t *tilingBuf, size_t tilingSize)
{
    int64_t wrongNum = -1;
    uint8_t *x = (uint8_t *)AscendC::GmAlloc(inputSize);
    uint8_t *y = (uint8_t *)AscendC::GmAlloc(outputSize);
    uint8_t *tiling = (uint8_t *)AscendC::GmAlloc(tilingSize);

    // ReadFile takes the file size by reference; keep inputSize itself untouched.
    size_t inputReadSize = inputSize;
    if (ReadFile("../input/input.bin", inputReadSize, x, inputSize)) {
        memcpy_s(tiling, tilingSize, tilingBuf, tilingSize);
        AscendC::SetKernelMode(KernelMode::AIV_MODE);
        ICPU_RUN_KF(broadcast_custom, blockDim, x, y, nullptr, tiling);  // use this macro for cpu debug

        WriteFile("../output/output.bin", y, outputSize);
        wrongNum = CompareResult(y, outputSize);
    } else {
        // Do not launch the kernel on an uninitialized input buffer.
        printf("ReadFile input failed!\n");
    }

    AscendC::GmFree((void *)x);
    AscendC::GmFree((void *)tiling);
    AscendC::GmFree((void *)y);
    return wrongNum;
}
#else
/**
 * Run broadcast_custom on device 0 through ACL: copy input and tiling to the device,
 * launch on a dedicated stream, wait for it, copy the result back, write
 * ../output/output.bin and compare it against the golden file.
 *
 * @param blockDim   number of cores passed to broadcast_custom_do.
 * @param inputSize  size in bytes of the input tensor x.
 * @param outputSize size in bytes of the output tensor y.
 * @param tilingBuf  host buffer of tilingSize bytes produced by GenerateTiling.
 * @param tilingSize size in bytes of the tiling data.
 * @return CompareResult's mismatch count, or -1 if the input file could not be read.
 */
static int64_t RunBroadcast(uint32_t blockDim, size_t inputSize, size_t outputSize,
                            const uint8_t *tilingBuf, size_t tilingSize)
{
    int64_t wrongNum = -1;
    CHECK_ACL(aclInit(nullptr));
    aclrtContext context = nullptr;
    int32_t deviceId = 0;
    CHECK_ACL(aclrtSetDevice(deviceId));
    CHECK_ACL(aclrtCreateContext(&context, deviceId));
    aclrtStream stream = nullptr;
    CHECK_ACL(aclrtCreateStream(&stream));

    uint8_t *xHost = nullptr;
    uint8_t *yHost = nullptr;
    uint8_t *xDevice = nullptr;
    uint8_t *yDevice = nullptr;
    uint8_t *tilingDevice = nullptr;

    CHECK_ACL(aclrtMallocHost((void **)(&xHost), inputSize));
    CHECK_ACL(aclrtMallocHost((void **)(&yHost), outputSize));
    CHECK_ACL(aclrtMalloc((void **)&xDevice, inputSize, ACL_MEM_MALLOC_HUGE_FIRST));
    CHECK_ACL(aclrtMalloc((void **)&yDevice, outputSize, ACL_MEM_MALLOC_HUGE_FIRST));
    CHECK_ACL(aclrtMalloc((void **)&tilingDevice, tilingSize, ACL_MEM_MALLOC_HUGE_FIRST));

    // ReadFile takes the file size by reference; keep inputSize itself untouched.
    size_t inputReadSize = inputSize;
    if (ReadFile("../input/input.bin", inputReadSize, xHost, inputSize)) {
        CHECK_ACL(aclrtMemcpy(xDevice, inputSize, xHost, inputSize, ACL_MEMCPY_HOST_TO_DEVICE));
        // The tiling data is copied straight from tilingBuf; no pinned staging buffer is needed.
        CHECK_ACL(aclrtMemcpy(tilingDevice, tilingSize, tilingBuf, tilingSize, ACL_MEMCPY_HOST_TO_DEVICE));

        broadcast_custom_do(blockDim, nullptr, stream, xDevice, yDevice, nullptr, tilingDevice);
        CHECK_ACL(aclrtSynchronizeStream(stream));

        CHECK_ACL(aclrtMemcpy(yHost, outputSize, yDevice, outputSize, ACL_MEMCPY_DEVICE_TO_HOST));
        WriteFile("../output/output.bin", yHost, outputSize);
        wrongNum = CompareResult(yHost, outputSize);
    } else {
        // Do not launch the kernel on an uninitialized input buffer.
        printf("ReadFile input failed!\n");
    }

    CHECK_ACL(aclrtFree(xDevice));
    CHECK_ACL(aclrtFree(yDevice));
    CHECK_ACL(aclrtFree(tilingDevice));
    CHECK_ACL(aclrtFreeHost(xHost));
    CHECK_ACL(aclrtFreeHost(yHost));

    CHECK_ACL(aclrtDestroyStream(stream));
    CHECK_ACL(aclrtDestroyContext(context));
    CHECK_ACL(aclrtResetDevice(deviceId));
    CHECK_ACL(aclFinalize());
    return wrongNum;
}
#endif

/**
 * Test entry point for broadcast_custom.
 *
 * With no argument, the [1, 48] input is broadcast to [96, 48]. With the single
 * argument "1", the [96, 1] input is broadcast to [96, 96].
 *
 * @return 0 when every output element matches the golden data, 1 otherwise.
 */
int32_t main(int32_t argc, char *argv[])
{
    uint32_t blockDim = 1;
    size_t inputSize = 48 * sizeof(float);
    size_t outputSize = 96 * 48 * sizeof(float);
    std::vector<int64_t> inputDims = {1, 48};
    std::vector<int64_t> outputDims = {96, 48};
    if (argc == 2 && strcmp(argv[1], "1") == 0) {
        inputSize = 96 * sizeof(float);
        outputSize = 96 * 96 * sizeof(float);
        inputDims = {96, 1};
        outputDims = {96, 96};
    }
    ge::Shape inputShape(inputDims);
    ge::Shape outputShape(outputDims);
    size_t tilingSize = TILINGDATA_SIZE * sizeof(uint32_t); // tilingData size , defined in broadcast_custom_tiling.h

    uint8_t *tilingBuf = GenerateTiling(inputShape, outputShape, sizeof(float));
    if (tilingBuf == nullptr) {
        printf("GenerateTiling failed!\n");
        printf("test failed!\n");
        return 1;
    }

    int64_t wrongNum = RunBroadcast(blockDim, inputSize, outputSize, tilingBuf, tilingSize);
    // NOTE(review): assumes GenerateTiling allocates with malloc, matching the original free().
    free(tilingBuf);

    if (wrongNum != 0) {
        printf("test failed!\n");
        return 1;  // Nonzero exit status lets callers detect the failure.
    }
    printf("test pass!\n");
    return 0;
}